// STL
#include <cmath>     // std::log(), std::exp(), std::sqrt()
#include <stdexcept> // std::runtime_error()
#include <string>    // std::to_string()
#include <vector>    // std::vector<>

// jScience
#include "jScience/linalg.hpp"         // Matrix<>, Vector<>, dot_product(), transpose()
#include "jScience/physics/consts.hpp" // PI


// NOTE: for energy function and free energy formulas, see: http://wiki.statsxx.com/view/Restricted_Boltzmann_machine

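// NOTE: constant (v-independent) terms of the free energy are omitted from free_energy(), free_energy_visible(), and free_energy_hidden(): they shift every free energy by the same amount and so carry no information about v (the hidden-unit constant can be obtained separately from free_energy_const())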

// NOTE: the softplus free energy is approximated as the ReLU free energy (the former requires complicated series expansions for the antiderivative)

//
// DESC: free energy of a single visible vector, up to an additive constant
//
//       F(v) = F_visible(v) + F_hidden(v)
//
// INPUT:  v -- visible vector
// OUTPUT: free_energy_visible(v) + free_energy_hidden(v)
// THROWS: std::runtime_error (propagated) if vtype or htype is not recognized; vtype is checked first
//
inline double machine_learning::RBM::free_energy(
                                                 const Vector<double> &v
                                                 ) const
{
    // the two contributions are computed in separate statements so that the visible term is always evaluated first
    const double f_visible = this->free_energy_visible(v);
    const double f_hidden  = this->free_energy_hidden(v);
    
    // NOTE: the constant term (free_energy_const()) is deliberately left out
    return f_visible + f_hidden;
}

//
// DESC: free energies of a batch of visible vectors, up to an additive constant
//
// INPUT:  V -- matrix whose rows are visible vectors
// OUTPUT: vector whose i-th entry is free_energy(V.row(i))
// THROWS: std::runtime_error (propagated) if vtype or htype is not recognized
//
inline std::vector<double> machine_learning::RBM::free_energy(
                                                              const Matrix<double> &V
                                                              ) const
{
    // the loop index uses the same type as the row count (avoids a signed/unsigned comparison)
    auto n_rows = V.size(0);
    
    std::vector<double> free_energies;
    free_energies.reserve(n_rows); // final size is known: one entry per row
    
    for(decltype(n_rows) i = 0; i < n_rows; ++i)
    {
        free_energies.push_back(this->free_energy(V.row(i)));
    }
    
    return free_energies;
}


//
// DESC: visible-unit contribution to the free energy, up to an additive constant
//
// INPUT:  v -- visible vector
// OUTPUT: -a.v plus a vtype-dependent quadratic term:
//           vtype 0    : (none)
//           vtype 1    : + 0.5*v.v
//           vtype 2, 3 : + sum of v_i^2 over the strictly positive entries of v
// THROWS: std::runtime_error for any other vtype
//
inline double machine_learning::RBM::free_energy_visible(
                                                         const Vector<double> &v
                                                         ) const
{
    // linear (bias) term, common to every visible type
    double free_energy = -dot_product(
                                      this->a,
                                      v
                                      );
    
    switch(this->vtype)
    {
        case 0:
            break;
        case 1:
            free_energy += 0.5*dot_product(
                                           v,
                                           v
                                           );
            break;
        case 2:
        case 3:
        {
            // the loop index uses the same type as the vector size (avoids a signed/unsigned comparison)
            auto nv = v.size();
            for(decltype(nv) i = 0; i < nv; ++i)
            {
                if(v(i) > 0.)
                {
                    free_energy += v(i)*v(i);
                }
            }
            break;
        }
        default:
            throw std::runtime_error("error: RBM type vtype=" + std::to_string(vtype) + " not recognized (or implemented) in free_energy_visible()");
    }
    
    return free_energy;
}

//
// DESC:
//
inline double machine_learning::RBM::free_energy_hidden(
                                                        const Vector<double> &v
                                                        ) const
{
    double free_energy;
    
    Vector<double> x = transpose(this->W)*v + this->b;
    
    switch(this->htype)
    {
        case 0:
            // NOTE: repeated multiplication followed by one call to std::log() should be faster than repeated addition of calls to std::log()
            free_energy = 1.;
            for(auto i = 0; i < x.size(); ++i)
            {
                free_energy *= (1. + std::exp(x(i)));
            }
            free_energy = -std::log(free_energy);
            break;
        case 1:
            free_energy = -0.5*dot_product(
                                           x,
                                           x
                                           );
            break;
        case 2:
        case 3:
            free_energy = -0.25*dot_product(
                                            x,
                                            x
                                            );
            break;
        default:
            throw std::runtime_error("error: RBM type htype=" + std::to_string(htype) + " not recognized (or implemented) in free_energy_hidden()");
            break;
    }
    
    return free_energy;
}

//
// DESC: constant (v-independent) hidden-unit term of the free energy, which free_energy() omits
//
// OUTPUT: htype 0    : 0
//         htype 1    : -nh*log(2*sqrt(PI/2))   (note 2*sqrt(PI/2) = sqrt(2*PI))
//         htype 2, 3 : -nh*log(PI/2)
// THROWS: std::runtime_error for any other htype
//
inline double machine_learning::RBM::free_energy_const() const
{
    // nh is converted to double before being negated: negating it directly would wrap around if nh has an unsigned type
    const double n_hidden = static_cast<double>(nh);
    
    double free_energy = 0.;
    
    switch(this->htype)
    {
        case 0:
            free_energy = 0.;
            break;
        case 1:
            free_energy = -n_hidden*std::log(2*std::sqrt(PI/2));
            break;
        case 2:
        case 3:
            free_energy = -n_hidden*std::log(PI/2);
            break;
        default:
            throw std::runtime_error("error: RBM type htype=" + std::to_string(htype) + " not recognized (or implemented) in free_energy_const()");
    }
    
    return free_energy;
}